function [SummaryTableDependency, TrialXTrialTable, SummaryTable] = VariableTrialsScriptOneExpResults(resultFile)
% VariableTrialsScriptOneExpResults  Analyse one Go/No-Go experiment result file.
%
% Collects the data of a single Go/No-Go experiment and builds the output
% tables described below. By default it assumes that there are at most 4
% period types in the experiment (Precue, Cue, Reward, ITI); when a fifth
% period is present the first two periods are merged into one.
%
% Input
%   resultFile   Path of the result file, read with load(). The file name
%                (without extension) is split at '_' and its first four
%                tokens are stored as Date, Phase, SessionNo and Animal.
%                Missing tokens are stored as ''.
%
% Outputs
%   SummaryTableDependency  Summary grouped by the type/outcome of the
%                           previous trial. Only filled when there are
%                           exactly two trial types; otherwise an empty table.
%   TrialXTrialTable        One row per trial.
%   SummaryTable            One row per trial type (sums / means of the
%                           trial-by-trial columns) plus a 'Total' row.
%
% Fields read from the loaded file (as used in this file):
%   sensorData       column 1 = time stamps, column 8 = licks,
%                    column 9 = pokes
%   periodsExecuted  column 1 = trial number, column 2 = trial type,
%                    column 3 = period number, columns 4/5 = start/stop
%                    sample index of the period
%   protocol         TrialsDefinition (names) and TrialSequence (IDs)
%   nLoopIterations  number of recorded samples
%
% The nested functions at the end of this file read NTrials, NTrialTypes,
% nPeriods, CueIDx, indices, PokesPerPeriodTable, VariableNames,
% TrialFileID, TrialName and data directly from this workspace, so these
% names must not be renamed here without updating them.
%
% Code created by Lukasz Piszczek, using parts of previous code from Elan
% and Manuel Pasieka.
% Version 2.2. Last update 23/03/2016
% Use migration script to bring data to newest format

% Load data
data = load(resultFile);

% Parameters
data.Time = data.sensorData(:,1)';   % row vector of time stamps, used by expandTrialTimes
if iscell(data.periodsExecuted)
    data.periodsExecuted = cell2mat(data.periodsExecuted);
end
NTrials = data.periodsExecuted(end, 1);
NTrialTypes = length(unique(data.periodsExecuted(:,2)));
MinLagTime = 0.5;
nPeriods = max(data.periodsExecuted(:,3));
% make idx of very first period starting with 1
data.periodsExecuted(1,4) = 1;
% Period positions after the optional merge below
CueIDx = 2;
RewardIDx = 3;
ITIIdx = 4;

% This iteration of protocol will assume presence of only one port, thus
% Reward port = Sensory port

% Only the file name is needed; directory and extension are discarded
[~, FileID] = fileparts(resultFile);
FileIDCell = strsplit(FileID,'_');

% extract the Trials name - THIS NEEDS TO BE CORRECTED
% (names are taken by position 1..NTrialTypes, not by trial type ID)
TrialName = cell(1,NTrialTypes);
for typeNo = 1:NTrialTypes
    TrialName(typeNo) = cellstr(data.protocol.TrialsDefinition{1, typeNo}.name);
end

% start basic functions
indices = expandTrialIndices(data);
times = expandTrialTimes(data);
trialIDs = getTrialIDs(data)';

% merge precue and premature (first two periods become one)
if nPeriods > 4
    % Extend the first period to the end of the second one, but only for
    % trials in which the second period was executed; otherwise the stop
    % index of the first period would be overwritten with 0.
    secondExecuted = indices(:,4) > 0;
    indices(secondExecuted, 2) = indices(secondExecuted, 4);
    indices(:,3:4) = [];
    times(:,1) = times(:,1) + times(:,2);
    times(:,2) = [];
    nPeriods = nPeriods - 1;
end

% Calculate pokes and licks per period (count type 2 = 0->1 transitions)
PokesPerPeriodTable = CalculateEvents(data.sensorData(:,9)', 2);
LicksPerPeriodTable = CalculateEvents(data.sensorData(:,8)', 2);

% calculate pokes per second for Precue (precue response rate) and ITI
PokesPerSecondTable = [PokesPerPeriodTable(:,1)./times(:,1), ...
                       PokesPerPeriodTable(:,ITIIdx)./times(:,ITIIdx)];

% calculate Premature pokes: a trial without a Cue period (start index 0)
% is counted as a premature response
PrematureResponses = double(indices(1:NTrials, 2*CueIDx - 1) == 0);

% calculate latencies in 2 ways
LatencyToPokeTable = LatencyToPoke(times, indices);

% calculate collected rewards: the Reward period was executed and the
% animal licked during the Reward or the ITI period
rewardGiven = indices(1:NTrials, 2*RewardIDx - 1) > 0;
lickedAfter = LicksPerPeriodTable(1:NTrials, RewardIDx) > 0 | ...
              LicksPerPeriodTable(1:NTrials, ITIIdx) > 0;
CollectedRewards = double(rewardGiven & lickedAfter);

% calculate lags
DelayTable = CalculateDelay(data.sensorData(:,1)', MinLagTime);

% Generate Trial x Trial Matrix
TrialNumber = (1:NTrials)'; % create a column for each consecutive trial
TrialXTrialMatrix = horzcat(TrialNumber, trialIDs, PrematureResponses, ...
    PokesPerPeriodTable, LicksPerPeriodTable, CollectedRewards, ...
    PokesPerSecondTable, LatencyToPokeTable, DelayTable);

% Create table with file ID for future collection.
% Columns 2..5 hold the first four '_'-separated tokens of the file name;
% tokens that are missing stay ''.
TrialFileID = cell(NTrials,5);
TrialFileID(:,1) = {FileID};
TrialFileID(:,2:5) = {''};
nTokens = min(numel(FileIDCell), 4);
for tokenNo = 1:nTokens
    TrialFileID(:,tokenNo+1) = FileIDCell(tokenNo);
end

% Make Trial by Trial Table - Output Table
VariableNames = {'TrialNo', 'TrialType', 'PrematurePokes', 'PokesPrecue','PokesCue', 'PokesReward', 'PokesITI', 'LicksPrecue','LicksCue', 'LicksReward', 'LicksITI', 'CollectedRewards', 'PrecueResponseRate', 'ITIResponseRate','LatencyToPoke', 'LatencyInPokedTrials', 'NoOfLags'};
TrialXTrialTable = [cell2table(TrialFileID, 'VariableNames',{'FileID', 'Date', 'Phase', 'SessionNo', 'Animal'}) array2table(TrialXTrialMatrix, 'VariableNames', VariableNames)];

% Summarize Trial x Trial Table, Create the output table
SummaryTable = SummarizeTrials(TrialXTrialMatrix);

% substitute numbers for proper names
[TrialXTrialTable, SummaryTable] = ChangeTrialNames(TrialXTrialTable, SummaryTable);

% change to categories (less space, easier to handle):
% the five file-ID columns, the trial number / group and the trial type
for catCol = 1:7
    TrialXTrialTable.(catCol) = categorical(TrialXTrialTable.(catCol));
    SummaryTable.(catCol) = categorical(SummaryTable.(catCol));
end

% Calculate trial dependencies
% Do this only if there are 2 trial types (Go and No-Go)!!
if NTrialTypes == 2
    % Append the previous-trial markers and the total time of each period
    [TrialDependencies, SummaryTableDependency] = CalculateDependency(TrialXTrialMatrix);
    TrialXTrialTable = [TrialXTrialTable TrialDependencies  array2table(times, 'VariableNames',{'Precue_total_time', 'Cue_total_time', 'Reward_total_time', 'ITI_total_time'})];
    % change to categories (less space, easier to handle)
    for catCol = 23:24   % the two columns contributed by TrialDependencies
        TrialXTrialTable.(catCol) = categorical(TrialXTrialTable.(catCol));
    end
    for catCol = 1:7
        SummaryTableDependency.(catCol) = categorical(SummaryTableDependency.(catCol));
    end
    % column 26 = 'PreviousTrialType', appended by CalculateDependency
    SummaryTableDependency.(26) = categorical(SummaryTableDependency.(26));
else
    SummaryTableDependency = table;
end

%Subfunctions
    function indices = expandTrialIndices(data)
        % Extracts the start and stop sample index of every period of
        % every trial from data.periodsExecuted.
        % By Manuel Pasieka
        %
        % Input
        %   data   result structure; only data.periodsExecuted is read.
        %          Each row describes one executed period: column 3 is the
        %          period number, columns 4 and 5 are its start and stop
        %          index.
        %
        % Output
        %   indices   NTrials x (2*nPeriods) matrix. Columns 2*p-1 and 2*p
        %             hold start and stop index of period p. Periods that
        %             were not executed in a trial keep two zeros.
        %
        % A new trial (row of the output) is started whenever the period
        % number of a row is smaller than the one of the previous row.
        %
        % Example of the output layout (4 periods):
        %    1     288     289     297     298     333     334     481
        %  482     708       0       0       0       0     709     848
        %  849    1061       0       0       0       0    1062    1204
        % 1205    1525    1526    1535    1536    1578    1579    1725
        %
        % NTrials and nPeriods are read from the enclosing function.

        try
            indices = zeros(NTrials, nPeriods*2);
            TrialCnt = 1;
            prevPeriod = 0;
            % Iterate over the rows. length() would return the largest
            % dimension, which is the column count when fewer rows than
            % columns were recorded.
            for rowNo = 1:size(data.periodsExecuted, 1)
                period = data.periodsExecuted(rowNo, 3);
                if period < prevPeriod
                    TrialCnt = TrialCnt + 1;
                end
                prevPeriod = period;
                indices(TrialCnt, period*2 - 1) = data.periodsExecuted(rowNo, 4);
                indices(TrialCnt, period*2) = data.periodsExecuted(rowNo, 5);
            end
        catch err
            error('Cought Error %s', err.message)
        end
    end

function times = expandTrialTimes(data)
% Builds a table with the duration of every period: one row per trial,
% one column per period. The duration is the difference of the time
% stamps (data.Time) at the stop and the start index of the period.
% Periods whose stop index is 0 (not executed) get a duration of 0.
%     Example
%    times =
%    19.0217    9.5562         0    9.8884 <-trial 1, no reward period
%    16.9976         0         0    9.8934 <-trial 2, premature poke, so no
%    cue and reward period
%    16.0574    9.5676    2.8699    9.9139 ...
%
% NTrials, nPeriods and indices are read from the enclosing function.
    startIdx = indices(1:NTrials, 1:2:2*nPeriods);
    stopIdx = indices(1:NTrials, 2:2:2*nPeriods);
    wasExecuted = stopIdx ~= 0;

    times = zeros(NTrials, nPeriods);
    times(wasExecuted) = data.Time(stopIdx(wasExecuted)) - data.Time(startIdx(wasExecuted));
end

    function trialIDs = getTrialIDs(data)
        % Returns the trial type ID (was it a go, no-go, etc.) of the
        % consecutive trials: the first NTrials entries of
        % data.protocol.TrialSequence. NTrials is read from the enclosing
        % function.
        fullSequence = cell2mat(data.protocol.TrialSequence);
        trialIDs = fullSequence(1:NTrials);
    end


    function cnt = countEvent( SensoryPortEventData, periodIndices, countType )
        % Counts, for one period of every trial, how often an event occurs
        % in a series of sensor data.
        %
        % Inputs
        %   SensoryPortEventData  row vector of sensor samples
        %   periodIndices         one row per trial: [start stop] index of
        %                         the period; a start of 0 marks a period
        %                         that was not executed
        %   countType             0 ... count zeros
        %                         1 ... count ones
        %                         2 ... count 01 transitions
        %                         3 ... count 10 transitions
        %
        % Output
        %   cnt   NTrials x 1 counts; NaN for periods that were not
        %         executed. NTrials is read from the enclosing function.
        %
        % Example (with NTrials = 2)
        % #######
        % EventData = [ 0 0 1 1 0 0 0 1 1 1 0 0 ]
        % PeriodIndices = [ 1 5; 6 12 ]
        %
        % countEvent( EventData, PeriodIndices, 0)   ->  3 ; 4
        % countEvent( EventData, PeriodIndices, 1)   ->  2 ; 3
        % countEvent( EventData, PeriodIndices, 3)   ->  1 ; 1
        %
        dims = size(periodIndices);
        assert( dims(1) > 0, 'periodIndices empty!')
        assert( dims(2) == 2, 'periodIndices bust have two columns!')
        assert( countType >= 0 && countType < 4, 'Wrong Count type!')

        % Build a mask that is 1 exactly where the requested event occurs.
        % A transition is attributed to the sample it arrives at.
        switch countType
            case 0
                eventMask = ~SensoryPortEventData;
            case 1
                eventMask = SensoryPortEventData;
            case 2
                eventMask = [ 0 diff(SensoryPortEventData) ] == 1;
            case 3
                eventMask = [ 0 diff(SensoryPortEventData) ] < 0;
        end

        cnt = zeros(NTrials, 1);
        for trialNo = 1:NTrials
            if periodIndices(trialNo,1) ~= 0
                cnt(trialNo) = sum( eventMask(periodIndices(trialNo,1):periodIndices(trialNo,2)) );
            else
                cnt(trialNo) = NaN;
            end
        end
    end
    
    function EventsTable = CalculateEvents(Data, countEventType)
        % Counts events of type countEventType (see countEvent) in Data for
        % every period: the result has one row per trial and one column
        % per period. nPeriods and indices are read from the enclosing
        % function.
        for periodNo = 1:nPeriods
            firstCol = 2*periodNo - 1;   % start-index column of this period
            EventsTable(:,periodNo) = countEvent(Data, indices(:,firstCol:firstCol+1), countEventType);
        end
    end

    function LatencyToPokeTable = LatencyToPoke(times, indices)
        % Extracts the latency to poke as an NTrials x 2 table.
        %
        % Column 1: duration of the Cue period as calculated in
        %           expandTrialTimes; NaN when the trial had no Cue period
        %           (its start index is 0, e.g. after a premature response).
        % Column 2: the same duration, but only for trials with at least
        %           one poke during the Cue period; NaN otherwise.
        %
        % NTrials, CueIDx and PokesPerPeriodTable are read from the
        % enclosing function.
        LatencyToPokeTable = nan(NTrials, 2);

        cueShown = indices(1:NTrials, CueIDx+1) ~= 0;
        LatencyToPokeTable(cueShown, 1) = times(cueShown, CueIDx);

        % Second column - take only poked trials, rest stays NaN
        pokedInCue = PokesPerPeriodTable(1:NTrials, 2) > 0;
        LatencyToPokeTable(pokedInCue, 2) = times(pokedInCue, CueIDx);
    end

    function DelayTable = CalculateDelay(TimeVector, MinLagTime)
        % Counts, for every trial, the number of lags: steps between two
        % consecutive time stamps that are longer than MinLagTime.
        % Returns an NTrials x 1 vector. NTrials, indices and
        % data.nLoopIterations are read from the enclosing function.
        nSamples = data.nLoopIterations;

        % Step from each sample to the next one; the last entry stays 0
        DelayList = zeros(nSamples, 1);
        DelayList(1:nSamples-1) = diff(TimeVector(1:nSamples));

        % A trial spans from the start of its first period (column 1) to
        % the stop of its last period (column 8)
        DelayTable = zeros(NTrials, 1);
        for trialNo = 1:NTrials
            trialLags = DelayList(indices(trialNo,1):indices(trialNo,8));
            DelayTable(trialNo,1) = nnz(trialLags > MinLagTime);
        end
    end

    function SummaryTable = SummarizeTrials(TrialXTrialMatrix)
        % Summarises a trial-by-trial matrix per trial type.
        %
        % Input
        %   TrialXTrialMatrix  numeric matrix with the 17 columns named in
        %                      VariableNames (may be empty, or a subset of
        %                      the trials).
        %
        % Output
        %   SummaryTable  table with NTrialTypes+1 rows: one per trial type
        %                 and a last row for all trials together. Columns:
        %                 the five file-ID columns, then
        %                   'Group' (renamed TrialNo column, left at 0),
        %                   TrialType,
        %                   columns 3:12  - sums,
        %                   columns 13:16 - means,
        %                   column 17     - sum (NoOfLags),
        %                   NoOfTrials,
        %                   PercentPokedTrialsWithoutPremature,
        %                   PercentOfPokedTrials.
        %                 For an empty input every row gets TrialType -1.
        %
        % NTrialTypes, VariableNames and TrialFileID are read from the
        % enclosing function. nansum / nanmean are called with an explicit
        % dimension so that a single-row input is still reduced column by
        % column and not collapsed to one scalar.
        % this should be rewritten for table format!!
        NoColumns = 20;   % 17 trial-by-trial columns + 3 derived columns
        SummaryTable = zeros(NTrialTypes+1, NoColumns); % allocate output table

        if isempty(TrialXTrialMatrix)
            % Placeholder for "no trials in this group"
            SummaryTable(NTrialTypes+1, NoColumns) = NaN;
            SummaryTable(:,2) = -1;
        else
            % calculate sum and mean for each trial type
            for typeRow = 1:NTrialTypes
                TempTable = TrialXTrialMatrix(TrialXTrialMatrix(:,2) == typeRow, :);
                if isempty(TempTable)
                    % No trial of this type: mark every statistic as missing
                    SummaryTable(typeRow,2) = typeRow;
                    SummaryTable(typeRow,3:NoColumns) = NaN;
                else
                    SummaryTable(typeRow,2) = TempTable(1,2);
                    SummaryTable(typeRow,3:12) = nansum(TempTable(:,3:12), 1);
                    SummaryTable(typeRow,13:16) = nanmean(TempTable(:,13:16), 1);
                    SummaryTable(typeRow,17) = nansum(TempTable(:,17), 1);
                    SummaryTable(typeRow,18) = size(TempTable, 1);
                    % Column 5 = pokes during Cue, column 3 = premature responses
                    SummaryTable(typeRow,19) = (SummaryTable(typeRow,5)/(SummaryTable(typeRow,18)-SummaryTable(typeRow,3)))*100;
                    SummaryTable(typeRow,20) = (SummaryTable(typeRow,5)/SummaryTable(typeRow,18))*100;
                end
            end

            % calculate sum and mean for all trials together
            totalRow = NTrialTypes+1;
            SummaryTable(totalRow,3:12) = nansum(TrialXTrialMatrix(:,3:12), 1);
            SummaryTable(totalRow,13:16) = nanmean(TrialXTrialMatrix(:,13:16), 1);
            SummaryTable(totalRow,17) = nansum(TrialXTrialMatrix(:,17), 1);
            SummaryTable(totalRow,18) = nansum(SummaryTable(1:NTrialTypes,18), 1);
            SummaryTable(totalRow,19) = nanmean(SummaryTable(1:NTrialTypes,19), 1);
            SummaryTable(totalRow,20) = nanmean(SummaryTable(1:NTrialTypes,20), 1);
        end

        % Attach the file-ID columns and name the variables
        VariableNamesSummary = [VariableNames, 'NoOfTrials', 'PercentPokedTrialsWithoutPremature','PercentOfPokedTrials'];
        SummaryTable = [cell2table(TrialFileID(1:NTrialTypes+1,:), 'VariableNames',{'FileID', 'Date', 'Phase', 'SessionNo', 'Animal'}) array2table(SummaryTable, 'VariableNames', VariableNamesSummary)];
        SummaryTable.Properties.VariableNames{6} = 'Group';
    end

    function [TrialXTrialTable, SummaryTable]= ChangeTrialNames(TrialXTrialTable, SummaryTable)
        % Replaces the numeric trial type in both tables with the name
        % given in the protocol description (TrialName); the last row of
        % the summary table is labelled 'Total'. NTrials, NTrialTypes and
        % TrialName are read from the enclosing function.

        % one (still empty) name per row of each table
        perTrialNames = cell(NTrials,1);
        perSummaryNames = cell(NTrialTypes+1,1);

        % fill in the name of each trial type wherever its number occurs
        for typeRow = 1:NTrialTypes
            typeLabel = TrialName(typeRow);
            perTrialNames(TrialXTrialTable.TrialType == typeRow, 1) = typeLabel;
            perSummaryNames(SummaryTable.TrialType == typeRow, 1) = typeLabel;
        end
        perSummaryNames{NTrialTypes+1,1} = 'Total';

        TrialXTrialTable.TrialType = perTrialNames;
        SummaryTable.TrialType = perSummaryNames;
    end

    function [TrialDependencies,SummaryTableDependency] = CalculateDependency(TrialXTrialMatrix)
        % Marks every trial with the type and the outcome of the trial
        % before it and summarises the trials grouped by these markers.
        %
        % Input
        %   TrialXTrialMatrix  numeric trial-by-trial matrix; column 2 is
        %                      the trial type (1 = Go, 2 = No-Go).
        %
        % Outputs
        %   TrialDependencies       table with the categorical columns
        %                           PreviousTrialType ('PreviousGo' /
        %                           'PreviousNoGo') and
        %                           PreviousTrialCueResponse
        %                           ('PreviousSuccess' / 'PreviousFailed');
        %                           the first trial gets 'None' in both.
        %   SummaryTableDependency  the SummarizeTrials output of the four
        %                           groups stacked (3 rows each), with a
        %                           last column PreviousTrialType naming
        %                           the group.
        %
        % Only done for 2 trial types; otherwise TrialDependencies is a
        % cell array filled with 'None' and SummaryTableDependency is an
        % empty table. NTrials, NTrialTypes, indices and TrialName are read
        % from the enclosing function.

        TrialDependencies = cell(NTrials,2);
        TrialDependencies(1,1:2) = {'None'};

        if NTrialTypes ~= 2
            TrialDependencies(:,1:2) = {'None'};
            SummaryTableDependency = table;
            return
        end

        % Labels are stored as plain char vectors so that every column is
        % a cellstr, which cell2table and categorical handle directly.
        for prevTrial = 1:NTrials-1
            if TrialXTrialMatrix(prevTrial,2) == 1
                TrialDependencies{prevTrial+1,1} = 'PreviousGo';
            elseif TrialXTrialMatrix(prevTrial,2) == 2
                TrialDependencies{prevTrial+1,1} = 'PreviousNoGo';
            end

            % The previous trial counts as a success when its Reward
            % period was executed (start index in column 5 is set)
            if indices(prevTrial, 5) > 0
                TrialDependencies{prevTrial+1,2} = 'PreviousSuccess';
            else
                TrialDependencies{prevTrial+1,2} = 'PreviousFailed';
            end
        end

        % make the output as a table
        TrialDependencies = cell2table(TrialDependencies, 'VariableNames', {'PreviousTrialType','PreviousTrialCueResponse'});

        % change the variables to categories
        for depCol = 1:2
            TrialDependencies.(depCol) = categorical(TrialDependencies.(depCol));
        end

        % Filter the trials according to the previous trial and summarize results
        TablePreviousTrialType = [SummarizeTrials(TrialXTrialMatrix(TrialDependencies.(1) == 'PreviousGo',:)); SummarizeTrials(TrialXTrialMatrix(TrialDependencies.(1) == 'PreviousNoGo',:))];
        TablePreviousSuccesses = [SummarizeTrials(TrialXTrialMatrix(TrialDependencies.(2) == 'PreviousSuccess',:)); SummarizeTrials(TrialXTrialMatrix(TrialDependencies.(2) == 'PreviousFailed',:))];

        % Group label for every summary row: each of the four summaries
        % contributes 3 rows (two trial types + total)
        SummaryTableDependency_IDs = cell(12,1);
        for groupRow = 1:3
            SummaryTableDependency_IDs{groupRow,1} = 'PreviousGo';
            SummaryTableDependency_IDs{groupRow+3,1} = 'PreviousNoGo';
            SummaryTableDependency_IDs{groupRow+6,1} = 'PreviousSuccess';
            SummaryTableDependency_IDs{groupRow+9,1} = 'PreviousFailed';
        end

        % Put together the Summary Table
        SummaryTableDependency = [TablePreviousTrialType; TablePreviousSuccesses];
        SummaryTableDependency = [SummaryTableDependency cell2table(SummaryTableDependency_IDs, 'VariableNames', {'PreviousTrialType'})];

        % Replace the numeric trial type with its name: 1..NTrialTypes by
        % the protocol name, 0 (total row) by 'Total', -1 (group without
        % trials) by 'DependencyNoneExistent'
        TrialsListNamesSummary = cell(size(SummaryTableDependency,1),1);
        for typeRow = 1:NTrialTypes
            TrialsListNamesSummary(SummaryTableDependency.TrialType == typeRow,1) = TrialName(typeRow);
        end
        TrialsListNamesSummary(SummaryTableDependency.TrialType == 0,1) = {'Total'};
        TrialsListNamesSummary(SummaryTableDependency.TrialType == -1,1) = {'DependencyNoneExistent'};
        SummaryTableDependency.TrialType = TrialsListNamesSummary;
    end

end %main function